from typing import Any, Mapping, Sequence, Union, Tuple

# import jax

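# NOTE: jax.random.KeyArray is deprecated in recent JAX releases (use
# jax.Array for key annotations) — update this before re-enabling it.
# PRNGKey = jax.random.KeyArray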
# # PyTree = Union[jax.typing.ArrayLike, Mapping[str, "PyTree"]]
# Recursive alias for configuration values: each value is either an
# arbitrary leaf (the ``Any`` arm) or a string-keyed mapping whose values
# are themselves ``Config`` (the forward reference makes it self-nesting).
# NOTE: since ``Any`` is a member of the union, static type checkers will
# accept any value for ``Config``; the ``Mapping`` arm chiefly documents the
# intended nested-dictionary shape.
Config = Union[
    Any,
    Mapping[str, "Config"],
]
# Params = Mapping[str, PyTree]
# Data = Mapping[str, PyTree]
# Shape = Sequence[int]
# Dtype = jax.typing.DTypeLike
